{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b3aad2e0-3b9a-4761-ae52-12fb4f29143a",
   "metadata": {},
   "source": [
    "# Handling Class Imbalance with AutoMM - Focal Loss\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/focal_loss.ipynb)\n",
    "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/focal_loss.ipynb)\n",
    "\n",
    "\n",
    "In this tutorial, we introduce how to use focal loss with the AutoMM package for balanced training.\n",
    "Focal loss is first introduced in this [Paper](https://arxiv.org/abs/1708.02002)\n",
    "and can be used for balancing hard/easy samples as well as un-even sample distribution among classes. This tutorial demonstrates how to use focal loss."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a764799-e835-420e-b51e-81dcfbda884c",
   "metadata": {},
   "source": [
    "## Create Dataset\n",
    "We use the shopee dataset for demonstration in this tutorial. Shopee dataset contains 4 classes and has 200 samples each in the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa00faab-252f-44c9-b8f7-57131aa8251c",
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "!pip install autogluon.multimodal\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c9c2134-bc31-4611-ab58-31701d8afdbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.multimodal.utils.misc import shopee_dataset\n",
    "\n",
    "download_dir = \"./ag_automm_tutorial_imgcls_focalloss\"\n",
    "train_data, test_data = shopee_dataset(download_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6965f254-5d25-4ee0-a58f-3c5b275d4ae9",
   "metadata": {},
   "source": [
    "For the purpose of demonstrating the effectiveness of Focal Loss on imbalanced training data, we artificially downsampled the shopee \n",
    "training data to form an imbalanced distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "918c3f82-06da-45c6-af8f-e847bbbc3e79",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "ds = 1\n",
    "\n",
    "imbalanced_train_data = []\n",
    "for lb in range(4):\n",
    "    class_data = train_data[train_data.label == lb]\n",
    "    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)\n",
    "    ds /= 3  # downsample 1/3 each time for each class\n",
    "    imbalanced_train_data.append(class_data.iloc[sample_index])\n",
    "imbalanced_train_data = pd.concat(imbalanced_train_data)\n",
    "print(imbalanced_train_data)\n",
    "\n",
    "weights = []\n",
    "for lb in range(4):\n",
    "    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]\n",
    "    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))\n",
    "    print(f\"class {lb}: num samples {len(class_data)}\")\n",
    "weights = list(np.array(weights) / np.sum(weights))\n",
    "print(weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ee803f4-06be-4e39-8347-4fe57992f2f8",
   "metadata": {},
   "source": [
    "## Create and train `MultiModalPredictor`\n",
    "\n",
    "### Train with Focal Loss\n",
    "We specify the model to use focal loss by setting the `\"optimization.loss_function\"` to `\"focal_loss\"`.\n",
    "There are also three other optional parameters you can set.\n",
    "\n",
    "`optimization.focal_loss.alpha` - a list of floats which is the per-class loss weight that can be used to balance un-even sample distribution across classes.\n",
    "Note that the `len` of the list ***must*** match the total number of classes in the training dataset. A good way to compute `alpha` for each class is to use the inverse of its percentage number of samples.\n",
    "\n",
    "`optimization.focal_loss.gamma` - float which controls how much to focus on the hard samples. Larger value means more focus on the hard samples.\n",
    "\n",
    "`optimization.focal_loss.reduction` - how to aggregate the loss value. Can only take `\"mean\"` or `\"sum\"` for now."
   ]
  },
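  {
   "cell_type": "markdown",
   "id": "7c2e4a91-3d5f-4b8a-9e60-1f2a3b4c5d6e",
   "metadata": {},
   "source": [
    "To make the roles of `alpha`, `gamma`, and `reduction` concrete, here is a minimal NumPy sketch of the focal loss from the paper applied to multiclass softmax outputs: `FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)`, where `p_t` is the predicted probability of the true class.\n",
    "The function `focal_loss_sketch` and the toy logits are illustrative assumptions, not AutoMM's internal implementation, which may differ in detail.\n",
    "With `gamma=0` and no `alpha`, the expression reduces to ordinary cross-entropy; raising `gamma` shrinks the loss of the confidently correct (easy) sample much faster than that of the hard one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b1d6f3-5a7c-4d2e-b9f0-2a3b4c5d6e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def focal_loss_sketch(logits, targets, alpha=None, gamma=2.0, reduction=\"mean\"):\n",
    "    \"\"\"Illustrative multiclass focal loss: -alpha_t * (1 - p_t) ** gamma * log(p_t).\"\"\"\n",
    "    logits = np.asarray(logits, dtype=float)\n",
    "    targets = np.asarray(targets, dtype=int)\n",
    "    # Numerically stable log-softmax over the class axis.\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
    "    # Log-probability assigned to the true class of each sample.\n",
    "    log_pt = log_probs[np.arange(len(targets)), targets]\n",
    "    pt = np.exp(log_pt)\n",
    "    # Per-class weight alpha_t; every class gets weight 1 when alpha is not given.\n",
    "    alpha_t = np.ones_like(pt) if alpha is None else np.asarray(alpha, dtype=float)[targets]\n",
    "    loss = -alpha_t * (1.0 - pt) ** gamma * log_pt\n",
    "    if reduction == \"mean\":\n",
    "        return loss.mean()\n",
    "    if reduction == \"sum\":\n",
    "        return loss.sum()\n",
    "    raise ValueError(\"reduction must be 'mean' or 'sum'\")\n",
    "\n",
    "\n",
    "# Hypothetical logits for a 4-class problem.\n",
    "toy_logits = np.array(\n",
    "    [\n",
    "        [4.0, 0.0, 0.0, 0.0],  # easy sample: confident and correct\n",
    "        [0.0, 0.5, 0.0, 0.0],  # hard sample: only slightly favors the true class\n",
    "    ]\n",
    ")\n",
    "toy_targets = np.array([0, 1])\n",
    "\n",
    "for g in (0.0, 1.0, 2.0):\n",
    "    easy = focal_loss_sketch(toy_logits[:1], toy_targets[:1], gamma=g, reduction=\"sum\")\n",
    "    hard = focal_loss_sketch(toy_logits[1:], toy_targets[1:], gamma=g, reduction=\"sum\")\n",
    "    print(f\"gamma={g}: easy={easy:.4f}, hard={hard:.4f}\")"
   ]
  },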
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcb4bb3f-47ee-4772-bd3b-3ba00c76cd60",
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "from autogluon.multimodal import MultiModalPredictor\n",
    "\n",
    "model_path = f\"./tmp/{uuid.uuid4().hex}-automm_shopee_focal\"\n",
    "\n",
    "predictor = MultiModalPredictor(label=\"label\", problem_type=\"multiclass\", path=model_path)\n",
    "\n",
    "predictor.fit(\n",
    "    hyperparameters={\n",
    "        \"model.mmdet_image.checkpoint_name\": \"swin_tiny_patch4_window7_224\",\n",
    "        \"env.num_gpus\": 1,\n",
    "        \"optimization.loss_function\": \"focal_loss\",\n",
    "        \"optimization.focal_loss.alpha\": weights,  # shopee dataset has 4 classes.\n",
    "        \"optimization.focal_loss.gamma\": 1.0,\n",
    "        \"optimization.focal_loss.reduction\": \"sum\",\n",
    "        \"optimization.max_epochs\": 10,\n",
    "    },\n",
    "    train_data=imbalanced_train_data,\n",
    ") \n",
    "\n",
    "predictor.evaluate(test_data, metrics=[\"acc\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0ae5688-e0f5-4a62-a36f-87ce772ea8ef",
   "metadata": {},
   "source": [
    "### Train without Focal Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fadf7437-53b7-454d-81e0-596732f355a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "from autogluon.multimodal import MultiModalPredictor\n",
    "\n",
    "model_path = f\"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal\"\n",
    "\n",
    "predictor2 = MultiModalPredictor(label=\"label\", problem_type=\"multiclass\", path=model_path)\n",
    "\n",
    "predictor2.fit(\n",
    "    hyperparameters={\n",
    "        \"model.mmdet_image.checkpoint_name\": \"swin_tiny_patch4_window7_224\",\n",
    "        \"env.num_gpus\": 1,\n",
    "        \"optimization.max_epochs\": 10,\n",
    "    },\n",
    "    train_data=imbalanced_train_data,\n",
    ")\n",
    "\n",
    "predictor2.evaluate(test_data, metrics=[\"acc\"])"
   ]
  },
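  {
   "cell_type": "markdown",
   "id": "3f9a0b7c-6d1e-4c2f-8a4b-5c6d7e8f9a0b",
   "metadata": {},
   "source": [
    "Overall accuracy can hide how a model treats the rare classes, so it may help to also look at accuracy per class.\n",
    "The cell below is a small sketch using pandas; the helper `per_class_accuracy` and the toy labels are hypothetical and only illustrate the computation.\n",
    "The commented lines show how it could be applied to the two predictors trained above through `predictor.predict`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d2c8e5-7f3a-4e1b-9c6d-8e9f0a1b2c3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def per_class_accuracy(y_true, y_pred):\n",
    "    \"\"\"Return the fraction of correctly predicted samples for each true class.\"\"\"\n",
    "    # Compare by position, so index differences between the inputs do not matter.\n",
    "    frame = pd.DataFrame({\"label\": list(y_true), \"pred\": list(y_pred)})\n",
    "    frame[\"correct\"] = frame[\"label\"] == frame[\"pred\"]\n",
    "    return frame.groupby(\"label\")[\"correct\"].mean()\n",
    "\n",
    "\n",
    "# Toy labels standing in for real predictions (hypothetical values).\n",
    "toy_true = [0, 0, 0, 0, 1, 1, 2, 3]\n",
    "toy_pred = [0, 0, 0, 0, 1, 0, 0, 3]\n",
    "print(per_class_accuracy(toy_true, toy_pred))\n",
    "\n",
    "# With the predictors trained above, the same helper could be applied as:\n",
    "# per_class_accuracy(test_data[\"label\"], predictor.predict(test_data))\n",
    "# per_class_accuracy(test_data[\"label\"], predictor2.predict(test_data))"
   ]
  },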
  {
   "cell_type": "markdown",
   "id": "ddb15b82-3af8-4837-8f2e-654dcc561e9d",
   "metadata": {},
   "source": [
    "As we can see that the model with focal loss is able to achieve a much better performance compared to the model without focal loss.\n",
    "When your data is imbalanced, try out focal loss to see if it brings improvements to the performance!\n",
    "\n",
    "## Citations\n",
    "\n",
    "```\n",
    "@misc{https://doi.org/10.48550/arxiv.1708.02002,\n",
    "  doi = {10.48550/ARXIV.1708.02002},\n",
    "  \n",
    "  url = {https://arxiv.org/abs/1708.02002},\n",
    "  \n",
    "  author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},\n",
    "  \n",
    "  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},\n",
    "  \n",
    "  title = {Focal Loss for Dense Object Detection},\n",
    "  \n",
    "  publisher = {arXiv},\n",
    "  \n",
    "  year = {2017},\n",
    "  \n",
    "  copyright = {arXiv.org perpetual, non-exclusive license}\n",
    "}\n",
    "\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
